'''
Author: hjie huangjie20011001@163.com
Date: 2024-12-13 14:44:38
'''
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import _LRScheduler

class GradualWarmupScheduler(_LRScheduler):
    """Linearly warm the learning rate up to its base value, then hold it.

    For the first ``warmup_epochs`` epochs each param group's LR is scaled by
    ``(last_epoch + 1) / warmup_epochs``; after warmup it stays at the base LR.

    Args:
        optimizer: wrapped optimizer whose param-group LRs are scheduled.
        total_epochs: total number of training epochs (stored for callers;
            not used by the warmup formula itself).
        warmup_epochs: number of epochs over which to ramp the LR up.
        last_epoch: index of the last finished epoch (-1 = fresh training).

    Raises:
        TypeError: if ``optimizer`` is not a ``torch.optim.Optimizer``.
    """

    def __init__(self, optimizer, total_epochs, warmup_epochs, last_epoch=-1):
        if not isinstance(optimizer, optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.total_epochs = total_epochs
        self.warmup_epochs = warmup_epochs
        # NOTE: _LRScheduler.__init__ records each group's initial LR and
        # populates self.base_lrs itself, so no manual copy is needed here
        # (the original assignment was immediately overwritten by super()).
        super(GradualWarmupScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Use (last_epoch + 1) so epoch 0 does not get a zero learning
            # rate: the original last_epoch / warmup_epochs formula froze
            # the entire first epoch at lr == 0.
            warm_frac = (self.last_epoch + 1) / self.warmup_epochs
            return [base_lr * warm_frac for base_lr in self.base_lrs]
        return list(self.base_lrs)

def decay_lr_poly(base_lr, epoch_i, batch_i, total_epochs, total_batches, warm_up, power=1.0, way='linear', stable_epochs: int = 1):
    '''
    Compute a per-batch learning rate: linear warmup, then a decay schedule.

    base_lr: initial (peak) learning rate.
    epoch_i, batch_i: current epoch / batch index (0-based).
    total_epochs, total_batches: total epochs / batches per epoch.
    warm_up: number of warmup epochs; during warmup the LR ramps linearly
        from 0 up to base_lr.
    power: exponent for the polynomial ('linear') decay.
    way: post-warmup schedule — 'linear' (polynomial), 'cos' (cosine), or
        'stable' (hold base_lr for stable_epochs, then cosine decay).
    stable_epochs: epochs to hold base_lr before decaying ('stable' only).

    Returns the learning rate to use for this batch.
    '''
    def cos_decay():
        # Cosine anneal from 1 down to 0 over all post-warmup batches.
        progress = (epoch_i - warm_up) * total_batches + batch_i
        return 0.5 * (1.0 + np.cos(np.pi * progress / ((total_epochs - warm_up) * total_batches)))

    def linear_decay():
        # Polynomial decay from 1 down to 0 (power == 1 -> linear).
        return np.power(
            1.0 - ((epoch_i - warm_up) * total_batches + batch_i) / ((total_epochs - warm_up) * total_batches),
            power)

    def stable_decay():
        # Hold at base_lr for stable_epochs after warmup, then cosine-decay.
        if epoch_i < warm_up + stable_epochs:
            return 1.0
        progress = (epoch_i - warm_up - stable_epochs) * total_batches + batch_i
        return 0.5 * (1.0 + np.cos(np.pi * progress / ((total_epochs - warm_up - stable_epochs) * total_batches)))

    # Dispatch table of callables so that ONLY the selected schedule is
    # evaluated. The original built this dict with the three results already
    # computed, which executed every branch eagerly and could divide by zero
    # in an unrelated schedule (e.g. way='linear' with total_epochs == warm_up
    # crashed inside cos_decay before the selected value was ever used).
    decay_way = {'linear': linear_decay, 'cos': cos_decay, 'stable': stable_decay}
    assert way in decay_way, 'ERROR: MUST CHOOSE WAY IN [linear, cos, stable]'
    if warm_up > 0 and epoch_i < warm_up:
        # Linear warmup: ramp from 0 to base_lr over warm_up epochs.
        rate = (epoch_i * total_batches + batch_i) / (warm_up * total_batches)
    else:
        rate = decay_way[way]()
    return rate * base_lr

if __name__ == '__main__':
    class LinearModel(nn.Module):
        """Minimal single-layer regression model for the scheduling demo."""

        def __init__(self):
            super(LinearModel, self).__init__()
            self.linear = nn.Linear(10, 1)

        def forward(self, x):
            return self.linear(x)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Synthetic regression data: 100 samples, 10 features -> 1 target.
    x = torch.randn(100, 10).to(device)
    y = torch.randn(100, 1).to(device)

    dataset = TensorDataset(x, y)
    dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

    model = LinearModel().to(device)
    loss_fn = nn.MSELoss()
    total_epochs, warmup_epochs, lr = 100, 20, 0.01
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Method 2: epoch-level warmup via the scheduler class. Currently unused;
    # see the commented scheduler.step() lines at the end of the epoch loop.
    scheduler = GradualWarmupScheduler(optimizer, total_epochs, warmup_epochs)

    for epoch in range(total_epochs):
        model.train()
        for batch_idx, (inputs, target) in enumerate(dataloader):
            # Method 1: compute the batch-level LR functionally each step.
            batch_lr = decay_lr_poly(lr, epoch, batch_idx, total_epochs, len(dataloader),
                                     warmup_epochs, 1, 'stable', stable_epochs=30)
            for group in optimizer.param_groups:
                group['lr'] = batch_lr
            optimizer.zero_grad()
            # Renamed from `input` so the loop variable does not shadow the
            # `input` builtin.
            output = model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
        # Method 2 (alternative): advance the scheduler once per epoch instead
        # of setting group['lr'] by hand inside the batch loop.
        # scheduler.step()
        # lr = scheduler.get_lr()[0]
        print(f'Epoch {epoch+1}, Learning Rate: {batch_lr:.6f}, Loss Value: {loss.item()}')